# Script used to perform PCA on the CA coordinates of aligned MD trajectories and PDB structures - Figure 4 supplement 2
import MDAnalysis as mda
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from MDAnalysis.analysis import align

# Trajectories to analyse, one DCD file per run (replace with your file names).
# NOTE(review): run8 is the only entry without the "_unwrapped" suffix - confirm
# that this is the intended file.
dcd_files = [
    '/beagle3/wtang/ACE_MD/ACE/run3/ACErun3_Feb012024_unwrapped.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run7/ACErun7_Feb012024_unwrapped.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run8/ACErun8_Feb012024.dcd',
    '/beagle3/wtang/ACE_MD/ACE/run9/ACErun9_Feb012024_unwrapped.dcd',
]

# The same PDB (dimeric protein) is used as the alignment reference and as the
# topology for every trajectory.
topology_file = '/beagle3/wtang/ACE_MD/ACE/ACE_protein_Zn.pdb'
reference = mda.Universe(topology_file)

# CA atoms of the protein (both subunits); used for the fit here and for the
# coordinate extraction further down.
atom_selection = "protein and name CA"

# Fit every trajectory onto the reference and collect the aligned universes,
# in the same order as dcd_files.
aligned_trajectories = []

for dcd_file in dcd_files:
    mobile = mda.Universe(topology_file, dcd_file)

    # in_memory=True requests in-place alignment, so the fitted frames are read
    # back from `mobile` later instead of from a newly written trajectory file.
    aligner = align.AlignTraj(mobile, reference, select=atom_selection, in_memory=True)
    aligner.run()

    aligned_trajectories.append(mobile)
    print(f'Aligned {dcd_file}')

# Build the PCA input from the aligned trajectories: one flattened coordinate
# vector per frame, plus a parallel label giving the index of the run
# (position in aligned_trajectories) that the frame came from.
frame_vectors = []
run_labels = []

for run_index, universe in enumerate(aligned_trajectories):
    selected = universe.select_atoms(atom_selection)

    # Stepping through the trajectory updates `selected.positions` for each
    # frame; flatten() copies them into an independent 1-D vector.
    run_frames = [selected.positions.flatten() for _ in universe.trajectory]

    frame_vectors.extend(run_frames)
    run_labels.extend([run_index] * len(run_frames))

sim_coordinates = np.array(frame_vectors)
labels = np.array(run_labels)
print('Simulation coordinates extracted')

# PDB structures to project alongside the simulation frames. They must already
# be aligned to the simulation data; no fitting is performed on them here.
pdb_files = [
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/365_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/315_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/305_aligned_apo.pdb',
    '/beagle3/wtang/ACE_MD/ACE/Analysis/PCA/299_aligned_apo.pdb'
]

# Extract one flattened coordinate vector per PDB structure, using the same
# atom selection as for the simulation frames.
pdb_coordinates = []

# Length every PDB vector must have so it can be stacked with the simulation rows.
expected_length = sim_coordinates.shape[-1]

for pdb_file in pdb_files:
    u_pdb = mda.Universe(pdb_file)
    atoms_pdb = u_pdb.select_atoms(atom_selection)
    flat_positions = atoms_pdb.positions.flatten()

    # A structure with missing or extra residues would otherwise only fail
    # later, with an opaque shape error that does not name the offending file.
    if flat_positions.size != expected_length:
        raise ValueError(
            f'{pdb_file}: selection "{atom_selection}" matched {len(atoms_pdb)} atoms, '
            f'but the simulation frames contain {expected_length // 3}; '
            'all structures must contain the same atoms to be combined for PCA'
        )

    pdb_coordinates.append(flat_positions)

pdb_coordinates = np.array(pdb_coordinates)
print('PDB coordinates extracted')

# Number of simulation rows; needed to separate the projections again below.
n_sim_rows = len(sim_coordinates)

# Stack the simulation frames on top of the PDB structures so that both are
# scaled and projected with the same model.
stacked = np.vstack((sim_coordinates, pdb_coordinates))

# Standardize each coordinate column before the decomposition.
scaler = StandardScaler()
stacked_standardized = scaler.fit_transform(stacked)
print('Data standardized')

# Project onto the first two principal components.
pca = PCA(n_components=2)
principal_components = pca.fit_transform(stacked_standardized)
print('PCA performed')

# Rows come back in stacking order: simulation frames first, PDB structures after.
sim_principal_components, pdb_principal_components = np.split(principal_components, [n_sim_rows])
print('Ready to plot')

# Plot the projection: simulation frames coloured by run index, PDB structures
# drawn as 'x' markers on top.
plt.figure()
scatter = plt.scatter(
    sim_principal_components[:, 0], sim_principal_components[:, 1], c=labels, cmap='viridis', label='Simulation data'
)
# Keep the artist so it can be given its own legend entry below.
pdb_scatter = plt.scatter(
    pdb_principal_components[:, 0], pdb_principal_components[:, 1], label='PDB structures', marker='x'
)
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.title('PCA of Simulation Data with PDB Comparison')

# Create a legend with one entry per run plus one for the PDB structures.
# num=None yields exactly one handle per distinct label value, and the run
# labels are derived from the same sorted distinct values, so handles and
# labels always pair up (even if a run contributed no frames).
handles, _ = scatter.legend_elements(num=None)
legend_labels = [f'Run {int(i) + 1}' for i in np.unique(labels)]
# plt.legend pairs handles with labels one-to-one, so the PDB scatter artist
# must be passed as a handle too; otherwise its label is silently dropped.
plt.legend(list(handles) + [pdb_scatter], legend_labels + ['PDB structures'])

plt.savefig('Apo_MD_vs_ACE_structures_colored.png', dpi=300, bbox_inches='tight')
plt.show()

print('All done')
